#include<cstdio>
// Prime modulus; main() reduces the answer by it before printing.
int mod = 998244353;
// Input values, read by main() with scanf.
int n = 0;
int m = 0;
// Result computed in main(); reduced modulo `mod` before it is printed.
long long ans = 0;
// Smaller of two ints. When a == b, the value of `a` is returned
// (identical either way).
int min(int a,int b)
{
	return (b < a) ? b : a;
}
int main()
{
	freopen("bpmp.in","r",stdin);
	freopen("bpmp.out","w",stdout);
	scanf("%d %d",&n,&m);
	if(n==1)
	{
		printf("%d",m-1);
		return 0;
	}
	int s=min(n,m);
	int l=m-s+n;
	ans=s-1+(s)*(l-1);
	ans%=mod; 
	printf("%d",ans);

